Allow KD loss in val mode for MLM plugin #331
Conversation
Signed-off-by: Asha Anoosheh <[email protected]>
Walkthrough

The Megatron distillation plugin's forward path now always computes both teacher (no_grad) and student outputs, with concatenation on non-final pipeline stages. The distillation shape-adjustment helper switches to a `**kwargs` signature with simplified gating, forwarding config/group via `**kwargs` to the shape utilities.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Plugin as DistillPlugin
    participant Teacher
    participant Student
    participant Pipeline as PipelineStage
    Note over Plugin,Pipeline: Previous flow (before)
    Caller->>Plugin: _forward(inputs, training=False)
    alt not training
        Plugin->>Student: forward(inputs)
        Student-->>Plugin: student_out
        Plugin-->>Caller: student_out
    else training / other
        Plugin->>Teacher: forward(inputs) (no_grad)
        Teacher-->>Plugin: teacher_out
        Plugin->>Student: forward(inputs)
        Student-->>Plugin: student_out
        alt not last stage
            Plugin->>Pipeline: concat(teacher_out, student_out)
            Pipeline-->>Caller: combined_out
        else last stage
            Plugin-->>Caller: student_out
        end
    end
```

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Plugin as DistillPlugin
    participant Teacher
    participant Student
    participant Pipeline as PipelineStage
    Note over Plugin,Pipeline: New flow (after)
    Caller->>Plugin: _forward(inputs, training any)
    par Compute teacher (no_grad)
        Plugin->>Teacher: forward(inputs)
        Teacher-->>Plugin: teacher_out
    and Compute student
        Plugin->>Student: only_student_forward(inputs)
        Student-->>Plugin: student_out
    end
    alt not last pipeline stage
        Plugin->>Pipeline: concat(teacher_out, student_out)
        Pipeline-->>Caller: combined_out
    else last pipeline stage
        Plugin-->>Caller: student_out
    end
```
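For readers skimming the diagrams, here is a minimal Python sketch of the new flow. The function and argument names are assumptions for illustration, not the plugin's actual API:

```python
import torch
from megatron.core import parallel_state


def distill_forward(teacher_model, only_student_forward, inputs):
    """Sketch of the post-change flow: the teacher always runs."""
    # The teacher forward is computed unconditionally (train and eval alike),
    # so the KD loss is also available during validation; no_grad keeps the
    # teacher out of autograd.
    with torch.no_grad():
        teacher_out = teacher_model(inputs)
    student_out = only_student_forward(inputs)

    if parallel_state.is_pipeline_last_stage():
        # The final stage returns only the student output; losses (including
        # the KD loss) are computed from there.
        return student_out
    # Intermediate stages pass both activations downstream; the concat axis
    # here is an assumption for illustration.
    return torch.cat([teacher_out, student_out], dim=-1)
```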
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/distill/plugins/megatron.py (1)
578-582: Whitelist kwargs passed to get_tensor_shapes; avoid accidental TypeError.

Passing raw `**kwargs` through can leak unexpected keys to `get_tensor_shapes(...)`. Filter to the known set before forwarding. Also, both teacher shape calls are identical; compute once and reuse.

Example fix (insert before the calls):

```python
allowed = {"seq_length", "micro_batch_size", "decoder_seq_length", "forward_only"}
_shape_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
```

Then update the calls:

```python
teacher_shapes = get_tensor_shapes(
    config=teacher_config, tp_group=tp_group, cp_group=cp_group, **_shape_kwargs
)
teacher_recv_tensor_shapes = teacher_shapes
teacher_send_tensor_shapes = teacher_shapes
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/distill/plugins/megatron.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/distill/plugins/megatron.py (1)
558-561: Gating change LGTM; confirm assumptions for non-interleaved PP.

Dropping `forward_only` from the gate matches the goal (enable KD in validation). Please confirm this is only exercised for non-interleaved PP and that eval runs with PP>1 don't hit unexpected shape paths with VPP configured.
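To make the gating point concrete, here is a hypothetical helper (not the plugin's actual code) showing the before/after condition under the assumption that the old gate also checked `forward_only`:

```python
from megatron.core import parallel_state


def shape_adjustment_needed(forward_only: bool) -> bool:
    """Hypothetical helper illustrating the gate change discussed above."""
    pp_size = parallel_state.get_pipeline_model_parallel_world_size()
    # Old gate (before this PR): pp_size > 1 and not forward_only, which
    # skipped the distillation shape adjustment during validation.
    # New gate: only the pipeline-parallel size matters, so eval runs with
    # PP > 1 also take the distillation-aware shape path.
    return pp_size > 1
```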
```python
    model: torch.nn.Module | list[torch.nn.Module], **kwargs
) -> Callable | None:
    """Return the function to adjust tensor shapes for Distillation in Megatron-Core's forward pass.
```
Avoid breaking API: keep positional args (backward‑compat) and funnel into kwargs.
Switching to only `**kwargs` will break any existing positional call sites with `TypeError: too many positional arguments`. Keep the old parameters (defaulted/optional), merge them into `kwargs`, and deprecate later.
Apply this diff to the signature:

```diff
-def get_tensor_shapes_adjust_fn_for_distillation(
-    model: torch.nn.Module | list[torch.nn.Module], **kwargs
-) -> Callable | None:
+def get_tensor_shapes_adjust_fn_for_distillation(
+    model: torch.nn.Module | list[torch.nn.Module],
+    seq_length: int | None = None,
+    micro_batch_size: int | None = None,
+    decoder_seq_length: int | None = None,
+    forward_only: bool | None = None,
+    **kwargs,
+) -> Callable | None:
```
Add this merge shim at the top of the function body:

```python
# Back-compat: funnel explicit args into kwargs if provided.
if seq_length is not None:
    kwargs.setdefault("seq_length", seq_length)
if micro_batch_size is not None:
    kwargs.setdefault("micro_batch_size", micro_batch_size)
if decoder_seq_length is not None:
    kwargs.setdefault("decoder_seq_length", decoder_seq_length)
if forward_only is not None:
    kwargs.setdefault("forward_only", forward_only)
```
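For context, a sketch of the two call styles that would both remain valid under the proposed signature; the variables below are placeholders, not actual Megatron-LM call sites:

```python
# Positional call sites written against the old signature keep working:
adjust_fn = get_tensor_shapes_adjust_fn_for_distillation(
    model, seq_length, micro_batch_size, decoder_seq_length, forward_only
)
# And so does the keyword style that matches the new **kwargs path:
adjust_fn = get_tensor_shapes_adjust_fn_for_distillation(
    model,
    seq_length=seq_length,
    micro_batch_size=micro_batch_size,
    decoder_seq_length=decoder_seq_length,
    forward_only=True,
)
```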
🤖 Prompt for AI Agents
In modelopt/torch/distill/plugins/megatron.py around lines 550-552, the function
signature was changed to only accept **kwargs which breaks backward
compatibility for callers using positional parameters; restore the original
explicit parameters (seq_length, micro_batch_size, decoder_seq_length,
forward_only) as optional/defaulted parameters in the signature, and at the top
of the function body add a back-compat shim that funnels any provided explicit
args into kwargs (using kwargs.setdefault) so existing positional call sites
continue to work; mark these explicit params as deprecated in a comment for
future removal.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #331      +/-   ##
==========================================
- Coverage   73.88%   73.82%   -0.06%
==========================================
  Files         172      172
  Lines       17444    17438       -6
==========================================
- Hits        12888    12874      -14
- Misses       4556     4564       +8
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Asha Anoosheh <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: new feature

Overview: Remove restrictions to allow `DistillationModel.compute_kd_loss()` to be called during Megatron validation.

Usage

```python
# Add a code snippet demonstrating how to use this
```
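A minimal sketch of what this change enables, assuming a model already converted by ModelOpt distillation; the validation loop and variable names are illustrative, not taken from this PR:

```python
import torch

# Assumes `model` is a DistillationModel wrapping teacher and student, and
# `val_iterator` yields ready-to-use validation batches.
model.eval()
with torch.no_grad():
    for batch in val_iterator:
        _ = model(**batch)                  # teacher + student forward
        kd_loss = model.compute_kd_loss()   # now callable in validation mode
        print("validation KD loss:", kd_loss)
```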
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Bug Fixes
Refactor